{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets,transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LeNet5(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(LeNet5,self).__init__()\n",
    "        \n",
    "        self.conv_unit = nn.Sequential(\n",
    "            nn.Conv2d(3,16,kernel_size=5,stride=1,padding=0),\n",
    "            nn.MaxPool2d(kernel_size=2,stride=2,padding=0),\n",
    "            nn.Conv2d(16,32,kernel_size=5,stride=1,padding=0),\n",
    "            nn.MaxPool2d(kernel_size=2,stride=2,padding=0),   \n",
    "        )\n",
    "        self.fc_unit = nn.Sequential(\n",
    "            nn.Linear(32*5*5,120),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Linear(120,84),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Linear(84,32),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Linear(32,10),\n",
    "            nn.ReLU(inplace=True),\n",
    "        )\n",
    "        \n",
    "    def forward(self,x):\n",
    "        batch_size = x.size(0)\n",
    "        x = self.conv_unit(x)\n",
    "        x = x.view(batch_size,32*5*5)\n",
    "        logits = self.fc_unit(x)\n",
    "        return logits\n",
    "    \n",
    "# def main():\n",
    "#     model = LeNet5()\n",
    "#     x = torch.randn(2,3,32,32)\n",
    "#     logits = model(x)\n",
    "#     print(logits.shape)\n",
    "    \n",
    "# if '__main__'==__name__:\n",
    "#     main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class ResnetBlk(nn.Module):\n",
    "    def __init__(self,ch_in,ch_out):\n",
    "        super(ResnetBlk,self).__init__()\n",
    "        self.unit = nn.Sequential(\n",
    "            nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1),\n",
    "            nn.BatchNorm2d(ch_out),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1),\n",
    "            nn.BatchNorm2d(ch_out),\n",
    "        )\n",
    "        self.extra = nn.Sequential()\n",
    "        if ch_out!=ch_in:\n",
    "            self.extra = nn.Sequential(\n",
    "                nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1),\n",
    "                nn.BatchNorm2d(ch_out),\n",
    "            )\n",
    "            \n",
    "    def forward(self,x):\n",
    "        out = self.unit(x)\n",
    "        return self.extra(x) + out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResNet18(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(ResNet18,self).__init__()\n",
    "        self.unit = nn.Sequential(\n",
    "            nn.Conv2d(3,16,stride=3,kernel_size=3,padding=0),\n",
    "            nn.BatchNorm2d(16),\n",
    "        )\n",
    "        self.bltk1=ResnetBlk(16,32)\n",
    "        self.bltk2=ResnetBlk(32,64)\n",
    "        self.bltk3=ResnetBlk(64,128)\n",
    "        self.bltk4=ResnetBlk(128,256)\n",
    "        self.outlayer=nn.Linear(256*10*10,10)\n",
    "        \n",
    "    def forward(self,x):\n",
    "        x = F.relu(self.unit(x))\n",
    "        x = self.bltk1(x)\n",
    "        x = self.bltk2(x)\n",
    "        x = self.bltk3(x)\n",
    "        x = self.bltk4(x)\n",
    "        x = x.view(x.size(0),-1)\n",
    "        x = self.outlayer(x)\n",
    "        \n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 128, 32, 32])\n",
      "resnet: torch.Size([2, 10])\n"
     ]
    }
   ],
   "source": [
    "# def main():\n",
    "#     blk = ResnetBlk(64,128)\n",
    "#     tmp = torch.randn(2,64,32,32)\n",
    "#     out = blk(tmp)\n",
    "#     print(out.shape)\n",
    "    \n",
    "#     model = ResNet18()\n",
    "#     tmp = torch.randn(2,3,32,32)\n",
    "#     out = model(tmp)\n",
    "#     print('resnet:',out.shape)\n",
    "    \n",
    "\n",
    "# if __name__=='__main__':\n",
    "#     main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def main():\n",
    "    batch_size = 32\n",
    "    cifar_train = datasets.CIFAR10('cifar',True,transforms.Compose([\n",
    "        transforms.Resize((32,32)),\n",
    "        transforms.ToTensor()\n",
    "    ]),download = True)\n",
    "    cifar_train = torch.utils.data.DataLoader(cifar_train,batch_size=batch_size,shuffle=True)\n",
    "    cifar_test = datasets.CIFAR10('cifar',False,transforms.Compose([\n",
    "        transforms.Resize((32,32)),\n",
    "        transforms.ToTensor()\n",
    "    ]),download = True)\n",
    "    cifar_test = torch.utils.data.DataLoader(cifar_test,batch_size=batch_size,shuffle=True)\n",
    "    x,label = iter(cifar_train).next()\n",
    "#     torch.Size([32, 3, 32, 32])\n",
    "#     torch.Size([32])\n",
    "    device = torch.device('cuda:0')\n",
    "    model = ResNet18().to(device)\n",
    "    criterion = nn.CrossEntropyLoss().to(device)\n",
    "    optimizer = optim.Adam(model.parameters(),lr = 3e-4)\n",
    "    \n",
    "    for epoch in range(1000):\n",
    "        model.train()\n",
    "        for idx,(train,target) in enumerate(cifar_train):\n",
    "            train,target = train.to(device),target.to(device)\n",
    "            logits = model(train)\n",
    "            loss = criterion(logits,target)\n",
    "            optim.zero_grad()\n",
    "            loss.backward()\n",
    "            optim.step()\n",
    "        print(epoch,'loss:',loss.item())\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            total_correct = 0\n",
    "            total_num = 0\n",
    "            for x,label in cifar_test:\n",
    "                #[b,3,32,32]\n",
    "                #[b]\n",
    "                x,label = x.to(device),label.to(device)\n",
    "                #[b,10]\n",
    "                logits = model(x)\n",
    "                pred = logits.argmax(dim=1)\n",
    "                total_correct += torch.eq(pred,target).float().sum().item()\n",
    "                total_num += x.size(0)\n",
    "            acc = total_correct / total_num\n",
    "            print(epoch, 'acc:', acc)\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "Error attempting to use dtype torch.float32 with layout torch.strided and device type CUDA.  Torch not compiled with CUDA enabled.\n",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-37-58ca95c5b364>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-36-4ada72d7259c>\u001b[0m in \u001b[0;36mmain\u001b[0;34m()\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[0;31m#     torch.Size([32])\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m     \u001b[0mdevice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'cuda:0'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m     \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mResNet18\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     18\u001b[0m     \u001b[0mcriterion\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mCrossEntropyLoss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m     \u001b[0moptimizer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moptim\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAdam\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mlr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m3e-4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mto\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    391\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    392\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 393\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    394\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    395\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_backward_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m    174\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    175\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 176\u001b[0;31m             \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    177\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    178\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mparam\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_parameters\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m    174\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    175\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 176\u001b[0;31m             \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    177\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    178\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mparam\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_parameters\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m    180\u001b[0m                 \u001b[0;31m# Tensors stored in modules are graph leaves, and we don't\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    181\u001b[0m                 \u001b[0;31m# want to create copy nodes, so we have to unpack the data.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 182\u001b[0;31m                 \u001b[0mparam\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    183\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mparam\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_grad\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    184\u001b[0m                     \u001b[0mparam\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_grad\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_grad\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m    391\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    392\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 393\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    394\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    395\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_backward_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Error attempting to use dtype torch.float32 with layout torch.strided and device type CUDA.  Torch not compiled with CUDA enabled.\n"
     ]
    }
   ],
   "source": [
    "main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
